"""Factor performance evaluation utilities."""
from datetime import date
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
from scipy import stats

from app.features.factors import (
    DEFAULT_FACTORS,
    FactorSpec,
    lookup_factor_spec,
)
from app.utils.db import db_session
from app.utils.logging import get_logger

LOGGER = get_logger(__name__)
LOG_EXTRA = {"stage": "factor_evaluation"}


class FactorPerformance:
    """Accumulated evaluation results for a single factor.

    Holds per-date IC / RankIC series and quintile return spreads, plus
    summary statistics filled in by the evaluation driver.
    """

    def __init__(self, factor_name: str) -> None:
        self.factor_name = factor_name
        # One entry per usable trade date: cross-sectional correlation between
        # factor values and next-date returns.
        self.ic_series: List[float] = []
        self.rank_ic_series: List[float] = []
        # Top-minus-bottom quintile return spread per usable trade date.
        self.return_spreads: List[float] = []
        self.sharpe_ratio: Optional[float] = None
        self.turnover_rate: Optional[float] = None
        self.sample_size: int = 0

    @property
    def ic_mean(self) -> float:
        """Mean of the IC series; 0.0 when no samples were collected."""
        if not self.ic_series:
            return 0.0
        return np.mean(self.ic_series)

    @property
    def ic_std(self) -> float:
        """Population standard deviation of the IC series; 0.0 when empty."""
        if not self.ic_series:
            return 0.0
        return np.std(self.ic_series)

    @property
    def ic_ir(self) -> float:
        """Information ratio (mean IC / IC std); 0.0 when the std is zero."""
        dispersion = self.ic_std
        if dispersion > 0:
            return self.ic_mean / dispersion
        return 0.0

    @property
    def rank_ic_mean(self) -> float:
        """Mean of the RankIC series; 0.0 when no samples were collected."""
        if not self.rank_ic_series:
            return 0.0
        return np.mean(self.rank_ic_series)

    def to_dict(self) -> Dict[str, float]:
        """Flatten the summary statistics into a plain dict."""
        sharpe = self.sharpe_ratio if self.sharpe_ratio else 0.0
        turnover = self.turnover_rate if self.turnover_rate else 0.0
        return {
            "ic_mean": self.ic_mean,
            "ic_std": self.ic_std,
            "ic_ir": self.ic_ir,
            "rank_ic_mean": self.rank_ic_mean,
            "sharpe_ratio": sharpe,
            "turnover_rate": turnover,
            "sample_size": float(self.sample_size),
        }


def evaluate_factor(
    factor_name: str,
    start_date: date,
    end_date: date,
    universe: Optional[List[str]] = None,
) -> FactorPerformance:
    """Evaluate the predictive power of a single factor.

    For every trade date in [start_date, end_date] that has factor data, this
    computes the cross-sectional Pearson IC and Spearman RankIC between the
    factor values and the next-trade-date close-to-close returns, plus the
    top-minus-bottom quintile return spread. From the spread series it derives
    a sqrt(252)-annualized Sharpe ratio, and it estimates the top-quintile
    turnover rate across the usable dates.

    Args:
        factor_name: Factor name; must exist as a column of the ``factors`` table.
        start_date: Evaluation window start (inclusive).
        end_date: Evaluation window end (inclusive).
        universe: Optional stock pool (ts_code list); ``None`` means all stocks.

    Returns:
        A ``FactorPerformance`` with per-date IC/RankIC series and summary
        statistics. Returned with empty series (all-zero derived stats) when
        the factor column is missing or the window holds no usable data.
    """
    performance = FactorPerformance(factor_name)
    spec = lookup_factor_spec(factor_name)
    factor_column = factor_name

    if spec is None:
        # Unknown to the factor registry, but the DB column may still exist —
        # warn and keep going.
        LOGGER.warning("未找到因子定义，仍尝试从数据库读取 factor=%s", factor_name, extra=LOG_EXTRA)

    normalized_universe = _normalize_universe(universe)
    start_str = start_date.strftime("%Y%m%d")
    end_str = end_date.strftime("%Y%m%d")

    with db_session(read_only=True) as conn:
        if not _has_factor_column(conn, factor_column):
            LOGGER.warning("factors 表缺少列 %s，跳过评估", factor_column, extra=LOG_EXTRA)
            return performance
        trade_dates = _list_factor_dates(conn, start_str, end_str, normalized_universe)

    if not trade_dates:
        LOGGER.info("指定区间内未找到可用因子数据 factor=%s", factor_name, extra=LOG_EXTRA)
        return performance

    usable_trade_dates: List[str] = []

    for trade_date_str in trade_dates:
        # Open a short-lived read-only session per date so the connection is
        # released before the numeric work below.
        with db_session(read_only=True) as conn:
            factor_map = _fetch_factor_cross_section(conn, factor_column, trade_date_str, normalized_universe)
            if not factor_map:
                continue
            next_trade = _next_trade_date(conn, trade_date_str)
            if not next_trade:
                # No later trading day — forward return cannot be computed.
                continue
            curr_close = _fetch_close_map(conn, trade_date_str, factor_map.keys())
            next_close = _fetch_close_map(conn, next_trade, factor_map.keys())

        # Pair each factor value with the stock's next-day simple return;
        # skip stocks missing either close or with a non-positive base price.
        factor_values: List[float] = []
        returns: List[float] = []
        for ts_code, value in factor_map.items():
            curr = curr_close.get(ts_code)
            nxt = next_close.get(ts_code)
            if curr is None or nxt is None or curr <= 0:
                continue
            factor_values.append(value)
            returns.append((nxt - curr) / curr)

        # Require a minimum cross-section size for a meaningful correlation.
        if len(factor_values) < 20:
            continue

        values_array = np.array(factor_values, dtype=float)
        returns_array = np.array(returns, dtype=float)
        # Near-constant series make the correlation undefined/unstable.
        if np.ptp(values_array) <= 1e-9 or np.ptp(returns_array) <= 1e-9:
            LOGGER.debug(
                "因子/收益序列波动不足，跳过 date=%s span_factor=%.6f span_return=%.6f",
                trade_date_str,
                float(np.ptp(values_array)),
                float(np.ptp(returns_array)),
                extra=LOG_EXTRA,
            )
            continue

        try:
            ic, _ = stats.pearsonr(values_array, returns_array)
            rank_ic, _ = stats.spearmanr(values_array, returns_array)
        except Exception as exc:  # noqa: BLE001
            LOGGER.debug("IC 计算失败 date=%s err=%s", trade_date_str, exc, extra=LOG_EXTRA)
            continue

        # Guard against NaN/inf results (e.g. from degenerate inputs).
        if not (np.isfinite(ic) and np.isfinite(rank_ic)):
            LOGGER.debug(
                "相关系数结果无效 date=%s ic=%s rank_ic=%s",
                trade_date_str,
                ic,
                rank_ic,
                extra=LOG_EXTRA,
            )
            continue

        performance.ic_series.append(ic)
        performance.rank_ic_series.append(rank_ic)
        usable_trade_dates.append(trade_date_str)

        # Quintile spread: mean return of the top 20% (by factor value) minus
        # the bottom 20%; skipped when fewer than 5 names survive.
        sorted_pairs = sorted(zip(values_array.tolist(), returns_array.tolist()), key=lambda item: item[0])
        quantile = len(sorted_pairs) // 5
        if quantile > 0:
            top_returns = [ret for _, ret in sorted_pairs[-quantile:]]
            bottom_returns = [ret for _, ret in sorted_pairs[:quantile]]
            spread = float(np.mean(top_returns) - np.mean(bottom_returns))
            performance.return_spreads.append(spread)

    # Annualized Sharpe of the daily long-short spread series.
    if performance.return_spreads:
        returns_mean = float(np.mean(performance.return_spreads))
        returns_std = float(np.std(performance.return_spreads))
        if returns_std > 0:
            performance.sharpe_ratio = returns_mean / returns_std * np.sqrt(252.0)

    performance.sample_size = len(usable_trade_dates)
    performance.turnover_rate = _estimate_turnover_rate(
        factor_column,
        usable_trade_dates,
        normalized_universe,
    )
    return performance


def combine_factors(
    factor_names: Sequence[str],
    weights: Optional[Sequence[float]] = None
) -> FactorSpec:
    """组合多个因子。
    
    Args:
        factor_names: 因子名称列表
        weights: 可选的权重列表，默认等权重
        
    Returns:
        组合因子的规格
    """
    if not weights:
        weights = [1.0 / len(factor_names)] * len(factor_names)
    
    name = "combined_" + "_".join(factor_names)
    window = max(
        spec.window
        for spec in DEFAULT_FACTORS
        if spec.name in factor_names
    )
    
    return FactorSpec(name, window)


def _normalize_universe(universe: Optional[Sequence[str]]) -> Optional[Tuple[str, ...]]:
    """Upper-case, strip, and de-duplicate codes, preserving first-seen order.

    Returns None when the input is empty or contains no usable codes.
    """
    if not universe:
        return None
    cleaned = ((code or "").strip().upper() for code in universe)
    ordered = tuple(dict.fromkeys(code for code in cleaned if code))
    return ordered if ordered else None


def _has_factor_column(conn, column: str) -> bool:
    """Return True when the `factors` table exposes the given column."""
    info_rows = conn.execute("PRAGMA table_info(factors)").fetchall()
    return any(row["name"] == column for row in info_rows)


def _list_factor_dates(conn, start: str, end: str, universe: Optional[Tuple[str, ...]]) -> List[str]:
    """Return the sorted distinct trade dates holding factor rows in [start, end].

    When a universe is given, only dates with rows for those codes count.
    """
    sql = (
        "SELECT DISTINCT trade_date FROM factors "
        "WHERE trade_date BETWEEN ? AND ?"
    )
    bind: List[str] = [start, end]
    if universe:
        marks = ",".join("?" for _ in universe)
        sql += f" AND ts_code IN ({marks})"
        bind.extend(universe)
    sql += " ORDER BY trade_date"
    return [
        record["trade_date"]
        for record in conn.execute(sql, bind).fetchall()
        if record and record["trade_date"]
    ]


def _fetch_factor_cross_section(
    conn,
    column: str,
    trade_date: str,
    universe: Optional[Tuple[str, ...]],
) -> Dict[str, float]:
    """Load {ts_code: factor value} for one date, dropping NULL/non-finite values.

    ``column`` is interpolated into the SQL, so callers must pass a validated
    column name (the driver checks it via PRAGMA first).
    """
    sql = f"SELECT ts_code, {column} AS value FROM factors WHERE trade_date = ? AND {column} IS NOT NULL"
    bind: List[str] = [trade_date]
    if universe:
        marks = ",".join("?" for _ in universe)
        sql += f" AND ts_code IN ({marks})"
        bind.extend(universe)
    cross_section: Dict[str, float] = {}
    for record in conn.execute(sql, bind).fetchall():
        code = record["ts_code"]
        raw = record["value"]
        if code is None or raw is None:
            continue
        try:
            numeric = float(raw)
        except (TypeError, ValueError):
            continue
        if np.isfinite(numeric):
            cross_section[code] = numeric
    return cross_section


def _next_trade_date(conn, trade_date: str) -> Optional[str]:
    """Return the earliest `daily` trade date strictly after ``trade_date``, or None."""
    cursor = conn.execute(
        "SELECT MIN(trade_date) AS next_date FROM daily WHERE trade_date > ?",
        (trade_date,),
    )
    record = cursor.fetchone()
    if record is None:
        return None
    return record["next_date"]


def _fetch_close_map(conn, trade_date: str, codes: Sequence[str]) -> Dict[str, float]:
    """Load {ts_code: close} from `daily` for one date, restricted to ``codes``.

    Codes without a row or with a NULL/unparsable close are omitted.
    """
    if not codes:
        return {}
    placeholders = ",".join("?" for _ in codes)
    sql = f"""
        SELECT ts_code, close
        FROM daily
        WHERE trade_date = ?
          AND ts_code IN ({placeholders})
          AND close IS NOT NULL
        """
    close_map: Dict[str, float] = {}
    for record in conn.execute(sql, [trade_date, *codes]).fetchall():
        code = record["ts_code"]
        raw = record["close"]
        if code is None or raw is None:
            continue
        try:
            close_map[code] = float(raw)
        except (TypeError, ValueError):
            continue
    return close_map


def _estimate_turnover_rate(
    factor_name: str,
    trade_dates: Sequence[str],
    universe: Optional[Tuple[str, ...]],
) -> Optional[float]:
    """Estimate average top-quintile membership turnover between consecutive dates.

    For each adjacent date pair, the "top" set is the codes at or above the
    80th percentile of factor values; turnover is the symmetric difference
    divided by the union of the two top sets. Returns the mean over all
    usable pairs, or None when no pair yields a sample.
    """
    if not trade_dates:
        return None
    samples: List[float] = []
    for prev_date, curr_date in zip(trade_dates, trade_dates[1:]):
        with db_session(read_only=True) as conn:
            prev_values = _fetch_factor_cross_section(conn, factor_name, prev_date, universe)
            curr_values = _fetch_factor_cross_section(conn, factor_name, curr_date, universe)

        if not prev_values or not curr_values:
            continue

        prev_cut = np.percentile(list(prev_values.values()), 80)
        curr_cut = np.percentile(list(curr_values.values()), 80)
        prev_top = {code for code, val in prev_values.items() if val >= prev_cut}
        curr_top = {code for code, val in curr_values.items() if val >= curr_cut}
        union = prev_top | curr_top
        if not union:
            continue
        samples.append(len(prev_top ^ curr_top) / len(union))

    return float(np.mean(samples)) if samples else None
